{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "toc": true
   },
   "source": [
    "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
    "<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"><li><span><a href=\"#Load-and-Prepare-Text8-data\" data-toc-modified-id=\"Load-and-Prepare-Text8-data-1\"><span class=\"toc-item-num\">1&nbsp;&nbsp;</span>Load and Prepare Text8 data</a></span></li></ul></div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN in TensorFlow for Text Data (NLP) <a class=\"tocSkip\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NumPy:1.13.1\n",
      "Pandas:0.21.0\n",
      "Matplotlib:2.1.0\n",
      "TensorFlow:1.4.1\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(123)\n",
    "print(\"NumPy:{}\".format(np.__version__))\n",
    "\n",
    "import tensorflow as tf\n",
    "tf.set_random_seed(123)\n",
    "print(\"TensorFlow:{}\".format(tf.__version__))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# location of the local datasetslib checkout, under the user's home directory\n",
    "DATASETSLIB_HOME = os.path.join(os.path.expanduser('~'), 'dl-ts', 'datasetslib')\n",
    "# add it to the import path only if it is not there yet, so re-running the cell adds no duplicates\n",
    "if DATASETSLIB_HOME not in sys.path:\n",
    "    sys.path.append(DATASETSLIB_HOME)\n",
    "\n",
    "# pick up edits to imported modules without restarting the kernel\n",
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import datasetslib\n",
    "from datasetslib import nputil\n",
    "from datasetslib import util as dsu\n",
    "\n",
    "# directory that datasetslib is pointed at as its datasets root\n",
    "datasetslib.datasets_root = os.path.join(os.path.expanduser('~'), 'datasets')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text Generation with Text8 Data in TensorFlow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and Prepare Text8 data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Already exists: /home/armando/datasets/text8/text8.zip\n",
      "Train: [  8 497   7   5 116]\n",
      "Vocabulary Length =  1457\n"
     ]
    }
   ],
   "source": [
    "from datasetslib.text8 import Text8\n",
    "\n",
    "text8 = Text8()\n",
    "# downloads data, converts words to ids, converts files to a list of ids\n",
    "text8.load_data(clip_at=5000)\n",
    "\n",
    "# peek at the first few word ids and the size of the vocabulary\n",
    "print('Train:', text8.part['train'][:5])\n",
    "print('Vocabulary Length = ', text8.vocab_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english revolution and the sans culottes of the french revolution whilst the term is still used in a pejorative way to describe any act that used violent means to destroy the organization of society it has also been taken up as a positive label by self defined anarchists the word anarchism is derived from the greek without archons ruler chief king anarchism as a political philosophy is the belief that rulers are unnecessary and should be abolished although there are differing\n"
     ]
    }
   ],
   "source": [
    "def id2string(ids):\n",
    "    \"\"\"Return the words for `ids`, looked up in text8.id2word and joined by single spaces.\"\"\"\n",
    "    words = (text8.id2word[word_id] for word_id in ids)\n",
    "    return ' '.join(words)\n",
    "\n",
    "# show the first 100 words of the training text\n",
    "print(id2string(text8.part['train'][:100]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear the effects of previous sessions in the Jupyter Notebook\n",
    "tf.reset_default_graph()\n",
    "# tf.set_random_seed() stores the seed on the current default graph, so the seed set\n",
    "# before the reset is discarded along with the old graph; set it again on the new graph\n",
    "tf.set_random_seed(123)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model and training parameters\n",
    "batch_size = 128\n",
    "state_size = 128\n",
    "learning_rate = 0.001\n",
    "\n",
    "# sequence layout\n",
    "n_x = 5  # number of input words\n",
    "n_y = 1  # number of output words\n",
    "n_x_vars = 1  # in case of our text, there is only 1 variable at each timestep\n",
    "n_y_vars = text8.vocab_len  # one output per vocabulary entry\n",
    "\n",
    "# inputs are fed as [batch, n_x, n_x_vars], targets as [batch, n_y_vars]\n",
    "x_p = tf.placeholder(dtype=tf.float32, shape=[None, n_x, n_x_vars], name='x_p')\n",
    "y_p = tf.placeholder(dtype=tf.float32, shape=[None, n_y_vars], name='y_p')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split x_p along axis 1 into a list of n_x tensors, one per timestep\n",
    "x_in = tf.unstack(value=x_p, axis=1, name='x_in')\n",
    "\n",
    "# can also be done using this (when n_x_vars is 1):\n",
    "# reshape the x placeholder to [batch, n_x]\n",
    "# x = tf.reshape(x_p,[-1,n_x])\n",
    "# split it along axis 1 into n_x tensors, one per timestep\n",
    "# x = tf.split(x,n_x,1,name='x')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a single LSTM cell with state_size units, statically unrolled over the list of inputs\n",
    "cell = tf.nn.rnn_cell.LSTMCell(num_units=state_size)\n",
    "rnn_outputs, final_states = tf.nn.static_rnn(cell=cell, inputs=x_in, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random 5 words:  free bolshevik be n another\n",
      "First 5 words:  anarchism originated as a term\n"
     ]
    }
   ],
   "source": [
    "# seed sequences that are fed to the model to generate text during training\n",
    "# n_x distinct ids drawn from the first n_x * 50 ids\n",
    "random5 = np.random.choice(a=n_x * 50, size=n_x, replace=False)\n",
    "# the first n_x ids of the training text (copied so the dataset is not modified later)\n",
    "first5 = text8.part['train'][:n_x].copy()\n",
    "\n",
    "print('Random 5 words: ',id2string(random5))\n",
    "print('First 5 words: ',id2string(first5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# output layer: project the RNN output onto the vocabulary (n_y_vars logits)\n",
    "\n",
    "# output node parameters\n",
    "# pass initializer instances (the documented form) rather than the initializer class\n",
    "w = tf.get_variable('w', [state_size, n_y_vars],\n",
    "                    initializer=tf.random_normal_initializer())\n",
    "b = tf.get_variable('b', [n_y_vars],\n",
    "                    initializer=tf.constant_initializer(0.0))\n",
    "\n",
    "# pick the last output only\n",
    "y_out = tf.matmul(rnn_outputs[-1], w) + b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mean softmax cross-entropy between the logits y_out and the targets y_p\n",
    "loss = tf.reduce_mean(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(labels=y_p, logits=y_out)\n",
    ")\n",
    "# `optimizer` holds the op returned by minimize(); running it performs one Adam update\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a prediction counts as correct when the argmax of the logits equals the argmax of the target\n",
    "n_correct_pred = tf.equal(tf.argmax(y_out, axis=1), tf.argmax(y_p, axis=1))\n",
    "# mean of the 0/1 correctness flags, i.e. the fraction of correct predictions\n",
    "accuracy = tf.reduce_mean(tf.cast(n_correct_pred, dtype=tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 99, Average loss:1.2983208000659943, Average accuracy:0.84375\n",
      "  Random 5 prediction: century warren own supported supported without without without without strongly\n",
      "  First 5 prediction: market argued individualist warren without without without without strongly strongly\n",
      "\n",
      "Epoch 199, Average loss:0.5452078034480413, Average accuracy:0.939453125\n",
      "  Random 5 prediction: been cnt also also also called syndicalism syndicalist operation syndicalists\n",
      "  First 5 prediction: spain force like politics ricardo key mag mag mag mag\n",
      "\n",
      "Epoch 299, Average loss:1.193410764137904, Average accuracy:0.8717447916666666\n",
      "  Random 5 prediction: tolstoy goods aiming anarchistic anarchistic anarchistic anarchistic anarchistic anarchistic anarchistic\n",
      "  First 5 prediction: their social groups groups authoritarian authoritarian authoritarian authoritarian authoritarian authoritarian\n",
      "\n",
      "Epoch 399, Average loss:1.2231902281443279, Average accuracy:0.8704427083333334\n",
      "  Random 5 prediction: long long associated associated anti anti left left authoritarian left\n",
      "  First 5 prediction: has movement anarchy post post post post post post post\n",
      "\n",
      "Epoch 499, Average loss:0.7656367868185043, Average accuracy:0.9140625\n",
      "  Random 5 prediction: noted mutual stirner warren warren tucker tucker tucker tucker tucker\n",
      "  First 5 prediction: her liberty noted own warren warren tucker tucker tucker tucker\n",
      "\n",
      "Epoch 599, Average loss:1.107410545150439, Average accuracy:0.8756510416666666\n",
      "  Random 5 prediction: syndicalists syndicalists propaganda propaganda propaganda ricardo national national mag mag\n",
      "  First 5 prediction: spanish force working syndicalists syndicalists propaganda propaganda propaganda ricardo national\n",
      "\n",
      "Epoch 699, Average loss:0.9093838532765707, Average accuracy:0.8854166666666666\n",
      "  Random 5 prediction: teachings jesus directive directive antifa antifa official relying relying relying\n",
      "  First 5 prediction: right who within communities tolstoy communities nonviolent official directive christianity\n",
      "\n",
      "Epoch 799, Average loss:0.752622996767362, Average accuracy:0.890625\n",
      "  Random 5 prediction: important generalizations bob hakim hakim hakim hakim hakim hakim hakim\n",
      "  First 5 prediction: individual include associated important important important bey bey bey bey\n",
      "\n",
      "Epoch 899, Average loss:0.41430705537398654, Average accuracy:0.9440104166666666\n",
      "  Random 5 prediction: benjamin egoism tucker tucker tucker tucker tucker tucker tucker tucker\n",
      "  First 5 prediction: self century tucker warren tucker tucker tucker tucker tucker tucker\n",
      "\n",
      "Epoch 999, Average loss:0.33439325789610547, Average accuracy:0.9485677083333334\n",
      "  Random 5 prediction: syndicalists syndicalists t syndicalists syndicalists unity t syndicalists syndicalists management\n",
      "  First 5 prediction: spain century syndicalist self spain syndicalist french spain propaganda management\n"
     ]
    }
   ],
   "source": [
    "# Train the model; every n_epochs_display epochs print the average loss/accuracy\n",
    "# and generate a few words from the two seed sequences (random5 and first5).\n",
    "n_epochs = 1000\n",
    "n_epochs_display = 100\n",
    "n_words_generate = 10  # number of words generated for each seed sequence\n",
    "# learning_rate is fixed when the optimizer op is built; reassigning it here would have no effect\n",
    "\n",
    "text8.reset_index()\n",
    "n_batches = text8.n_batches_seq(n_tx=n_x,n_ty=n_y)\n",
    "\n",
    "with tf.Session() as tfs:\n",
    "    tf.global_variables_initializer().run()\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        epoch_loss = 0\n",
    "        epoch_accuracy = 0\n",
    "        for step in range(n_batches):\n",
    "            x_batch, y_batch = text8.next_batch_seq(n_tx=n_x,n_ty=n_y)\n",
    "            y_batch = nputil.to2d(y_batch,unit_axis=1)\n",
    "            # size the one-hot targets from the batch actually returned, so a batch\n",
    "            # whose length differs from batch_size still lines up with x_batch\n",
    "            n_rows = len(y_batch)\n",
    "            y_onehot = np.zeros(shape=[n_rows,text8.vocab_len],dtype=np.float32)\n",
    "            for i in range(n_rows):\n",
    "                y_onehot[i,y_batch[i]]=1\n",
    "\n",
    "            feed_dict = {x_p: x_batch.reshape(-1, n_x, n_x_vars), y_p: y_onehot}\n",
    "\n",
    "            _, batch_accuracy, batch_loss = tfs.run([optimizer, accuracy, loss], feed_dict=feed_dict)\n",
    "            epoch_loss += batch_loss\n",
    "            epoch_accuracy += batch_accuracy\n",
    "        if (epoch+1) % (n_epochs_display) == 0:\n",
    "            epoch_loss = epoch_loss / n_batches\n",
    "            epoch_accuracy = epoch_accuracy / n_batches\n",
    "            print('\\nEpoch {0:}, Average loss:{1:}, Average accuracy:{2:}'.\n",
    "              format(epoch,epoch_loss,epoch_accuracy ))\n",
    "\n",
    "            # predicted word ids are stored as integers so they can be used as lookups\n",
    "            y_pred_r5 = np.empty([n_words_generate], dtype=np.int64)\n",
    "            y_pred_f5 = np.empty([n_words_generate], dtype=np.int64)\n",
    "\n",
    "            # work on copies so the seed sequences are the same at every display step\n",
    "            x_test_r5 = random5.copy()\n",
    "            x_test_f5 = first5.copy()\n",
    "            # let us generate text of n_words_generate words after feeding n_x words\n",
    "            for i in range(n_words_generate):\n",
    "                for x,y in zip([x_test_r5,x_test_f5],[y_pred_r5,y_pred_f5]):\n",
    "                    feed_dict = {x_p: x.reshape(-1, n_x, n_x_vars)}\n",
    "                    y_pred = tfs.run(y_out, feed_dict=feed_dict)\n",
    "                    # take the argmax in NumPy: tf.argmax(...).eval() would add a new op\n",
    "                    # to the graph on every call, so the graph would keep growing\n",
    "                    y_pred_id = int(np.argmax(y_pred, axis=1)[0])\n",
    "                    y[i]=y_pred_id\n",
    "                    # slide the input window: drop the oldest word, append the prediction\n",
    "                    x[:-1] = x[1:]\n",
    "                    x[-1] = y_pred_id\n",
    "            print('  Random 5 prediction:',id2string(y_pred_r5))\n",
    "            print('  First 5 prediction:',id2string(y_pred_f5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
